#pragma once

#include <random>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "torch/all.h"

// Forward-declare OpenCV's matrix type so this header does not have to pull in
// the OpenCV headers; cv::Mat remains an incomplete type here.
namespace cv { class Mat; }

// NOTE(review): `Mat` is not referenced by the declarations in this header;
// presumably the implementation file relies on this alias — verify.
using cv::Mat;

// NOTE(review): a using-directive at header scope injects all of namespace std
// into every translation unit that includes this file. It is kept because the
// declarations below use unqualified std names (`vector`, `pair`, `string`,
// `tuple`, `mt19937`), and presumably so does the implementation.
using namespace std;

/// LeNet-style classifier module made of two torch::nn::Sequential stages,
/// `features` and `classifier`. Their layers are set up by the constructor's
/// definition, which is not part of this header.
class LeNetImpl : public torch::nn::Module
{
public:
    /// @param classes Number of output classes (defaults to 10).
    /// `explicit` so that a bare int cannot silently convert into a network
    /// (e.g. when passed where a LeNetImpl is expected).
    explicit LeNetImpl(int classes = 10);

    /// Runs the network on `x` and returns the resulting tensor.
    /// NOTE(review): expected input shape/dtype and whether the output is raw
    /// logits or probabilities are defined by the implementation — confirm there.
    torch::Tensor forward(torch::Tensor x);

    /// Classifies `image` and returns a list of (int, float) pairs — presumably
    /// (class index, score); the ordering and number of entries are defined by
    /// the implementation — confirm there.
    vector<pair<int, float>> predict(torch::Tensor image);

private:
    torch::nn::Sequential features;    // presumably the feature-extraction layers
    torch::nn::Sequential classifier;  // presumably the classification head
};

// LibTorch's TORCH_MODULE macro declares the holder class `LeNet` wrapping
// `LeNetImpl`; client code is expected to use `LeNet` rather than `LeNetImpl`.
TORCH_MODULE(LeNet);

class DataLoader
{
public:
    DataLoader(const string& sampleFile, const string& labelFile);
    tuple<torch::Tensor, torch::Tensor> batch(int size = 100);
    tuple<torch::Tensor, torch::Tensor> all() const;

private:
    void loadImages(const string& file);
    void loadLabels(const string& file);
    int reverse(int x);

private:
    torch::Tensor inputs;
    torch::Tensor labels;
    mt19937 mt;
};



